{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "e70a58c6",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "cuda\n"
     ]
    }
   ],
   "source": [
    "from transformers import AutoTokenizer\n",
    "import torch\n",
    "from datasets import load_dataset, load_from_disk\n",
    "import torch.cuda as cuda\n",
    "\n",
    "#加载分词器\n",
    "tokenizer = AutoTokenizer.from_pretrained('hfl/chinese-roberta-wwm-ext')\n",
    "\n",
    "device = torch.device('cuda' if cuda.is_available() else 'cpu')\n",
    "print(device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "0567ccd4",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Loading cached processed dataset at d:\\暂存\\ResumeSystem\\ner\\xt\\train\\cache-343894c14063acb7.arrow\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "19999\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "(['2',\n",
       "  '0',\n",
       "  '0',\n",
       "  '9',\n",
       "  '年',\n",
       "  '：',\n",
       "  '李',\n",
       "  '民',\n",
       "  '基',\n",
       "  '《',\n",
       "  'E',\n",
       "  't',\n",
       "  'e',\n",
       "  'r',\n",
       "  'n',\n",
       "  'a',\n",
       "  'l',\n",
       "  '#',\n",
       "  'S',\n",
       "  'u',\n",
       "  'm',\n",
       "  'm',\n",
       "  'e',\n",
       "  'r',\n",
       "  '》'],\n",
       " [0, 0, 0, 0, 0, 0, 1, 2, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0])"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "class Dataset(torch.utils.data.Dataset):\n",
    "    def __init__(self, split):\n",
    "        #离线加载数据集\n",
    "        dataset = load_from_disk(dataset_path='./msra')[split]\n",
    "\n",
    "        #过滤掉太长的句子\n",
    "        def f(data):\n",
    "            return len(data['tokens']) <= 512 - 2\n",
    "\n",
    "        dataset = dataset.filter(f)\n",
    "\n",
    "        self.dataset = dataset\n",
    "\n",
    "    def __len__(self):\n",
    "        return len(self.dataset)\n",
    "\n",
    "    def __getitem__(self, i):\n",
    "        tokens = self.dataset[i]['tokens']\n",
    "        labels = self.dataset[i]['ner_tags']\n",
    "\n",
    "        return tokens, labels\n",
    "\n",
    "\n",
    "dataset = Dataset('train')\n",
    "print(len(dataset))\n",
    "dataset[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "e59695a4",
   "metadata": {},
   "outputs": [],
   "source": [
    "#数据整理函数\n",
    "def collate_fn(data):\n",
    "    tokens = [i[0] for i in data]\n",
    "    labels = [i[1] for i in data]\n",
    "\n",
    "    inputs = tokenizer.batch_encode_plus(tokens,\n",
    "                                         truncation=True,\n",
    "                                         padding=True,\n",
    "                                         return_tensors='pt',\n",
    "                                         is_split_into_words=True)\n",
    "\n",
    "    lens = inputs['input_ids'].shape[1]\n",
    "\n",
    "    for i in range(len(labels)):\n",
    "        labels[i] = [7] + labels[i]\n",
    "        labels[i] += [7] * lens\n",
    "        labels[i] = labels[i][:lens]\n",
    "\n",
    "    # 将 inputs 和 labels 移动到设备上\n",
    "    inputs = {key: value.to(device) for key, value in inputs.items()}\n",
    "    labels = torch.LongTensor(labels).to(device)\n",
    "\n",
    "    return inputs, labels"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "e5b2fcf2",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.\n"
     ]
    }
   ],
   "source": [
    "# 数据加载器\n",
    "loader = torch.utils.data.DataLoader(dataset=dataset,\n",
    "                                     batch_size=8,\n",
    "                                     collate_fn=collate_fn,\n",
    "                                     shuffle=True,\n",
    "                                     drop_last=True)\n",
    "\n",
    "for i, (inputs, labels) in enumerate(loader):\n",
    "    inputs = {k: v.to(device) for k, v in inputs.items()}\n",
    "    labels = labels.to(device)\n",
    "    break\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "f90b5b5a",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Some weights of the model checkpoint at hfl/chinese-roberta-wwm-ext were not used when initializing BertModel: ['cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.seq_relationship.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.decoder.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.bias', 'cls.predictions.transform.dense.weight']\n",
      "- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
      "- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n"
     ]
    },
    {
     "ename": "NameError",
     "evalue": "name 'inputs' is not defined",
     "output_type": "error",
     "traceback": [
      "\u001b[1;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[1;31mNameError\u001b[0m                                 Traceback (most recent call last)",
      "Cell \u001b[1;32mIn[4], line 9\u001b[0m\n\u001b[0;32m      5\u001b[0m pretrained \u001b[39m=\u001b[39m pretrained\u001b[39m.\u001b[39mto(device)\n\u001b[0;32m      7\u001b[0m \u001b[39m#模型试算\u001b[39;00m\n\u001b[0;32m      8\u001b[0m \u001b[39m#[b, lens] -> [b, lens, 768]\u001b[39;00m\n\u001b[1;32m----> 9\u001b[0m pretrained(\u001b[39m*\u001b[39m\u001b[39m*\u001b[39minputs)\u001b[39m.\u001b[39mlast_hidden_state\u001b[39m.\u001b[39mshape\n",
      "\u001b[1;31mNameError\u001b[0m: name 'inputs' is not defined"
     ]
    }
   ],
   "source": [
    "from transformers import AutoModel\n",
    "\n",
    "#加载预训练模型\n",
    "pretrained = AutoModel.from_pretrained('hfl/chinese-roberta-wwm-ext')\n",
    "pretrained = pretrained.to(device)\n",
    "pretrained(**inputs).last_hidden_state.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "3096c294",
   "metadata": {},
   "outputs": [
    {
     "ename": "NameError",
     "evalue": "name 'inputs' is not defined",
     "output_type": "error",
     "traceback": [
      "\u001b[1;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[1;31mNameError\u001b[0m                                 Traceback (most recent call last)",
      "Cell \u001b[1;32mIn[5], line 46\u001b[0m\n\u001b[0;32m     44\u001b[0m \u001b[39m# model = Model()\u001b[39;00m\n\u001b[0;32m     45\u001b[0m model \u001b[39m=\u001b[39m model\u001b[39m.\u001b[39mto(device)\n\u001b[1;32m---> 46\u001b[0m model(inputs)\u001b[39m.\u001b[39mshape\n",
      "\u001b[1;31mNameError\u001b[0m: name 'inputs' is not defined"
     ]
    }
   ],
   "source": [
    "# 定义下游模型\n",
    "class Model(torch.nn.Module):\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "        self.tuning = False\n",
    "        self.pretrained = None\n",
    "\n",
    "        self.rnn = torch.nn.GRU(768, 768, batch_first=True)\n",
    "        self.fc = torch.nn.Linear(768, 8)\n",
    "        \n",
    "    def forward(self, inputs):\n",
    "        if self.tuning:\n",
    "            out = self.pretrained(**inputs).last_hidden_state\n",
    "        else:\n",
    "            with torch.no_grad():\n",
    "                out = pretrained(**inputs).last_hidden_state\n",
    "\n",
    "        out, _ = self.rnn(out)\n",
    "        \n",
    "        out = self.fc(out).softmax(dim=2)\n",
    "        \n",
    "        return out\n",
    "\n",
    "    def fine_tuning(self, tuning):\n",
    "        self.tuning = tuning\n",
    "        if tuning:\n",
    "            for i in pretrained.parameters():\n",
    "                i.requires_grad = True\n",
    "\n",
    "            pretrained.train()\n",
    "            self.pretrained = pretrained\n",
    "        else:\n",
    "            for i in pretrained.parameters():\n",
    "                i.requires_grad_(False)\n",
    "\n",
    "            pretrained.eval()\n",
    "            self.pretrained = None\n",
    "\n",
    "\n",
    "model = Model()\n",
    "model = model.to(device)\n",
    "model(inputs).shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "942184e2",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(tensor([[ 0.3735,  0.1839, -1.2947, -0.5799,  1.4103,  1.9831,  1.1971,  0.4692],\n",
       "         [-0.4211,  0.2562,  0.1699, -0.2309,  0.3574,  0.7039,  0.0447,  1.1816],\n",
       "         [-0.0445, -0.8757, -1.3016,  1.7092, -0.9039, -1.2696,  0.9966,  0.3699],\n",
       "         [ 1.0213, -0.4037, -1.7713,  0.2936,  1.2625,  1.4149, -0.2500, -1.3706],\n",
       "         [-0.1633,  0.1126, -0.5640, -0.5190, -0.7649, -0.6759, -0.1617, -0.0638],\n",
       "         [-0.0021,  1.6459, -0.0471, -0.4606,  0.1794,  0.8733, -0.3115, -1.3731]]),\n",
       " tensor([1., 1., 1., 1., 1., 1.]))"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "#对计算结果和label变形,并且移除pad\n",
    "def reshape_and_remove_pad(outs, labels, attention_mask):\n",
    "    #变形,便于计算loss\n",
    "    #[b, lens, 8] -> [b*lens, 8]\n",
    "    outs = outs.reshape(-1, 8)\n",
    "    #[b, lens] -> [b*lens]\n",
    "    labels = labels.reshape(-1)\n",
    "\n",
    "    #忽略对pad的计算结果\n",
    "    #[b, lens] -> [b*lens - pad]\n",
    "    select = attention_mask.reshape(-1) == 1\n",
    "    outs = outs[select]\n",
    "    labels = labels[select]\n",
    "\n",
    "    return outs, labels\n",
    "\n",
    "\n",
    "reshape_and_remove_pad(torch.randn(2, 3, 8), torch.ones(2, 3),\n",
    "                       torch.ones(2, 3))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "8dab97e8",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(3, 16, 3, 16)"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# 获取正确数量和总数\n",
    "def get_correct_and_total_count(labels, outs):\n",
    "    #[b*lens, 8] -> [b*lens]\n",
    "    outs = outs.argmax(dim=1)\n",
    "    correct = (outs == labels).sum().item()\n",
    "    total = len(labels)\n",
    "\n",
    "    # 计算除了0以外元素的正确率\n",
    "    select = labels != 0\n",
    "    outs = outs[select]\n",
    "    labels = labels[select]\n",
    "    correct_content = (outs == labels).sum().item()\n",
    "    total_content = len(labels)\n",
    "\n",
    "    return correct, total, correct_content, total_content\n",
    "\n",
    "get_correct_and_total_count(torch.ones(16), torch.randn(16, 8))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "1bd44a7c",
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers import AdamW\n",
    "from torch.optim.lr_scheduler import StepLR\n",
    "def train(epochs):\n",
    "    lr = 1e-5 if model.tuning else 3e-4\n",
    "    optimizer = AdamW(model.parameters(), lr=lr)\n",
    "    criterion = torch.nn.CrossEntropyLoss()\n",
    "    scheduler = StepLR(optimizer, step_size=4, gamma=0.3)\n",
    "    \n",
    "    model.train()\n",
    "    for epoch in range(epochs):\n",
    "        for step, (inputs, labels) in enumerate(loader):\n",
    "            outs = model(inputs)\n",
    "            outs, labels = reshape_and_remove_pad(outs, labels,\n",
    "                                                  inputs['attention_mask'])\n",
    "            loss = criterion(outs, labels)\n",
    "            loss.backward()\n",
    "            optimizer.step()\n",
    "            optimizer.zero_grad()\n",
    "\n",
    "            if step % 400 == 0:\n",
    "                counts = get_correct_and_total_count(labels, outs)\n",
    "\n",
    "                accuracy = counts[0] / counts[1]\n",
    "                accuracy_content = counts[2] / counts[3]\n",
    "\n",
    "                print(epoch, step, loss.item(), accuracy, accuracy_content)\n",
    "        scheduler.step()\n",
    "        torch.save(model, 'ner5.model')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "e5fa189c",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "354.9704\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "c:\\Users\\12236\\AppData\\Local\\Programs\\Python\\Python311\\Lib\\site-packages\\transformers\\optimization.py:407: FutureWarning: This implementation of AdamW is deprecated and will be removed in a future version. Use the PyTorch implementation torch.optim.AdamW instead, or set `no_deprecation_warning=True` to disable this warning\n",
      "  warnings.warn(\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0 0 1.4061245918273926 0.8676470588235294 0.8125\n",
      "0 400 1.3827636241912842 0.8911917098445595 0.8958333333333334\n",
      "0 800 1.336877465248108 0.9354838709677419 0.851063829787234\n",
      "0 1200 1.4235705137252808 0.8524590163934426 0.7954545454545454\n",
      "0 1600 1.3927167654037476 0.8717948717948718 0.8305084745762712\n",
      "0 2000 1.3143255710601807 0.9586919104991394 0.765625\n",
      "0 2400 1.3752126693725586 0.8958333333333334 0.8653846153846154\n",
      "1 0 1.3365840911865234 0.941908713692946 0.7894736842105263\n",
      "1 400 1.395945429801941 0.8661417322834646 0.9508196721311475\n",
      "1 800 1.3376469612121582 0.935251798561151 0.7777777777777778\n",
      "1 1200 1.3243870735168457 0.9518987341772152 0.8769230769230769\n",
      "1 1600 1.3798080682754517 0.8955223880597015 0.8620689655172413\n",
      "1 2000 1.4270416498184204 0.8473282442748091 0.847457627118644\n",
      "1 2400 1.3653578758239746 0.9078947368421053 0.8076923076923077\n",
      "2 0 1.3485286235809326 0.9254658385093167 0.6883116883116883\n",
      "2 400 1.4106417894363403 0.8541666666666666 0.7142857142857143\n",
      "2 800 1.3221021890640259 0.9577464788732394 0.8636363636363636\n",
      "2 1200 1.3328737020492554 0.9379310344827586 0.8679245283018868\n",
      "2 1600 1.287092685699463 0.9866369710467706 0.9444444444444444\n",
      "2 2000 1.2753733396530151 1.0 1.0\n",
      "2 2400 1.3435217142105103 0.9299363057324841 0.631578947368421\n",
      "3 0 1.3631733655929565 0.915 0.8873239436619719\n",
      "3 400 1.3408864736557007 0.9308176100628931 0.9298245614035088\n",
      "3 800 1.4138520956039429 0.8611111111111112 0.8\n",
      "3 1200 1.2749402523040771 1.0 1.0\n",
      "3 1600 1.3859864473342896 0.8924731182795699 0.8305084745762712\n",
      "3 2000 1.4205231666564941 0.8555555555555555 0.8076923076923077\n",
      "3 2400 1.3842880725860596 0.8913043478260869 0.7972972972972973\n",
      "4 0 1.5075132846832275 0.7627118644067796 0.6231884057971014\n",
      "4 400 1.28827702999115 0.986784140969163 0.9772727272727273\n",
      "4 800 1.5129084587097168 0.7593984962406015 0.8035714285714286\n",
      "4 1200 1.304489016532898 0.9605263157894737 0.9491525423728814\n",
      "4 1600 1.3043776750564575 0.9688581314878892 0.85\n",
      "4 2000 1.2846952676773071 0.9894179894179894 0.9649122807017544\n",
      "4 2400 1.2760257720947266 1.0 1.0\n",
      "5 0 1.3097718954086304 0.9661016949152542 0.9047619047619048\n",
      "5 400 1.275580883026123 1.0 1.0\n",
      "5 800 1.3517627716064453 0.9176470588235294 0.8936170212765957\n",
      "5 1200 1.2967264652252197 0.9775280898876404 0.875\n",
      "5 1600 1.2980386018753052 0.9761904761904762 0.9193548387096774\n",
      "5 2000 1.3147801160812378 0.9609375 0.9090909090909091\n",
      "5 2400 1.3756659030914307 0.8984375 0.9107142857142857\n",
      "6 0 1.3535027503967285 0.9203539823008849 0.96\n",
      "6 400 1.3825846910476685 0.893048128342246 0.6428571428571429\n",
      "6 800 1.3640373945236206 0.9101123595505618 0.8688524590163934\n",
      "6 1200 1.291549563407898 0.9841269841269841 0.9411764705882353\n",
      "6 1600 1.2982434034347534 0.9761904761904762 0.9473684210526315\n",
      "6 2000 1.345873236656189 0.9285714285714286 0.7283950617283951\n",
      "6 2400 1.2783414125442505 0.9956140350877193 0.9661016949152542\n",
      "7 0 1.29753577709198 0.9780701754385965 0.9019607843137255\n",
      "7 400 1.294156789779663 0.979757085020243 0.9183673469387755\n",
      "7 800 1.3195937871932983 0.9542168674698795 0.703125\n",
      "7 1200 1.3005859851837158 0.9754098360655737 0.9347826086956522\n",
      "7 1600 1.3138395547866821 0.9607843137254902 1.0\n",
      "7 2000 1.371401309967041 0.900990099009901 0.77\n",
      "7 2400 1.3110793828964233 0.9663865546218487 0.9682539682539683\n",
      "8 0 1.2741575241088867 1.0 1.0\n",
      "8 400 1.3567185401916504 0.9176470588235294 0.9107142857142857\n",
      "8 800 1.336592197418213 0.9376693766937669 0.9\n",
      "8 1200 1.3247429132461548 0.9493670886075949 0.9166666666666666\n",
      "8 1600 1.2806460857391357 0.9933774834437086 1.0\n",
      "8 2000 1.3589357137680054 0.9148936170212766 0.7647058823529411\n",
      "8 2400 1.2855772972106934 0.9915966386554622 0.9827586206896551\n",
      "9 0 1.3126474618911743 0.9622641509433962 0.8947368421052632\n",
      "9 400 1.3124134540557861 0.9615384615384616 0.8048780487804879\n",
      "9 800 1.2934716939926147 0.9818181818181818 0.9333333333333333\n",
      "9 1200 1.31006920337677 0.964824120603015 0.9846153846153847\n",
      "9 1600 1.2789678573608398 0.9953775038520801 0.9558823529411765\n",
      "9 2000 1.2742353677749634 1.0 1.0\n",
      "9 2400 1.319138765335083 0.9532163742690059 0.8507462686567164\n"
     ]
    }
   ],
   "source": [
    "model.fine_tuning(False)\n",
    "print(sum(p.numel() for p in model.parameters()) / 10000)\n",
    "train(10)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "bc02d6dd",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "10581.7352\n",
      "0 0 1.281691074371338 0.99581589958159 0.9818181818181818\n",
      "0 400 1.4263242483139038 0.84375 0.7931034482758621\n",
      "0 800 1.3035809993743896 0.9774436090225563 0.9361702127659575\n",
      "0 1200 1.4049749374389648 0.861878453038674 0.576271186440678\n",
      "0 1600 1.3129820823669434 0.961038961038961 0.9464285714285714\n",
      "0 2000 1.2954390048980713 0.9794520547945206 0.9577464788732394\n",
      "0 2400 1.326935052871704 0.9532710280373832 0.9411764705882353\n",
      "1 0 1.307201862335205 0.9702970297029703 0.9782608695652174\n",
      "1 400 1.4246934652328491 0.8504672897196262 0.7837837837837838\n",
      "1 800 1.2742036581039429 1.0 1.0\n",
      "1 1200 1.286851406097412 0.9873417721518988 0.9814814814814815\n",
      "1 1600 1.5959538221359253 0.664 0.8620689655172413\n",
      "1 2000 1.2948931455612183 0.9776785714285714 0.9074074074074074\n",
      "1 2400 1.2933577299118042 0.9789915966386554 0.8979591836734694\n",
      "2 0 1.3438622951507568 0.9298780487804879 0.7529411764705882\n",
      "2 400 1.2742880582809448 1.0 1.0\n",
      "2 800 1.339837670326233 0.9343065693430657 0.8615384615384616\n",
      "2 1200 1.287752628326416 0.9840637450199203 0.9354838709677419\n",
      "2 1600 1.3434927463531494 0.93048128342246 0.8\n",
      "2 2000 1.3013510704040527 0.9727463312368972 0.803030303030303\n",
      "2 2400 1.3032867908477783 0.9719298245614035 0.8695652173913043\n",
      "3 0 1.2763350009918213 1.0 1.0\n",
      "3 400 1.2830088138580322 0.9923076923076923 1.0\n",
      "3 800 1.383424997329712 0.8913043478260869 0.8507462686567164\n",
      "3 1200 1.2756723165512085 1.0 1.0\n",
      "3 1600 1.3253273963928223 0.953125 0.9285714285714286\n",
      "3 2000 1.28341543674469 1.0 1.0\n",
      "3 2400 1.3402801752090454 0.9337748344370861 0.9230769230769231\n",
      "4 0 1.3250752687454224 0.9468085106382979 0.84375\n",
      "4 400 1.3250385522842407 0.9488636363636364 0.85\n",
      "4 800 1.2740148305892944 1.0 1.0\n",
      "4 1200 1.2897316217422485 0.9866071428571429 0.9516129032258065\n",
      "4 1600 1.3051677942276 0.9689922480620154 0.9459459459459459\n",
      "4 2000 1.3222603797912598 0.948905109489051 0.9183673469387755\n",
      "4 2400 1.3070043325424194 0.9695652173913043 0.9054054054054054\n",
      "5 0 1.3989176750183105 0.875 0.7619047619047619\n",
      "5 400 1.3313510417938232 0.9424083769633508 0.8269230769230769\n",
      "5 800 1.2875802516937256 0.9942528735632183 1.0\n",
      "5 1200 1.346081018447876 0.9279279279279279 0.84\n",
      "5 1600 1.3008015155792236 0.9732142857142857 0.9508196721311475\n",
      "5 2000 1.3422363996505737 0.9316239316239316 0.9230769230769231\n",
      "5 2400 1.4447122812271118 0.8285714285714286 0.7272727272727273\n",
      "6 0 1.2863551378250122 0.9876543209876543 0.98\n",
      "6 400 1.3598307371139526 0.9183673469387755 0.8867924528301887\n",
      "6 800 1.2924754619598389 0.9833333333333333 0.9285714285714286\n",
      "6 1200 1.3435945510864258 0.9304347826086956 0.8596491228070176\n",
      "6 1600 1.2741714715957642 1.0 1.0\n",
      "6 2000 1.3558440208435059 0.9173553719008265 0.9423076923076923\n",
      "6 2400 1.312970519065857 0.961038961038961 0.8732394366197183\n",
      "7 0 1.275989055633545 1.0 1.0\n",
      "7 400 1.2773332595825195 1.0 1.0\n",
      "7 800 1.2773927450180054 1.0 1.0\n",
      "7 1200 1.2741137742996216 1.0 1.0\n",
      "7 1600 1.3123936653137207 0.9705882352941176 0.9473684210526315\n",
      "7 2000 1.2740108966827393 1.0 1.0\n",
      "7 2400 1.2740191221237183 1.0 1.0\n",
      "8 0 1.274038553237915 1.0 1.0\n",
      "8 400 1.2975566387176514 0.9762931034482759 0.7843137254901961\n",
      "8 800 1.2903560400009155 0.9841772151898734 1.0\n",
      "8 1200 1.321557641029358 0.9537815126050421 0.7962962962962963\n",
      "8 1600 1.29928719997406 0.9747899159663865 0.9333333333333333\n",
      "8 2000 1.2740180492401123 1.0 1.0\n",
      "8 2400 1.29728102684021 0.9764150943396226 0.9555555555555556\n",
      "9 0 1.2938380241394043 0.98 0.9375\n",
      "9 400 1.3054555654525757 0.9685534591194969 0.9180327868852459\n",
      "9 800 1.288570523262024 0.9855072463768116 0.9\n",
      "9 1200 1.2743858098983765 1.0 1.0\n",
      "9 1600 1.3125838041305542 0.961456102783726 0.8\n",
      "9 2000 1.2742619514465332 1.0 1.0\n",
      "9 2400 1.2957044839859009 0.9792387543252595 0.8909090909090909\n",
      "10 0 1.3530429601669312 0.9208333333333333 0.7611940298507462\n",
      "10 400 1.324249505996704 0.9519230769230769 0.8771929824561403\n",
      "10 800 1.305650234222412 0.9710982658959537 0.9193548387096774\n",
      "10 1200 1.2740187644958496 1.0 1.0\n",
      "10 1600 1.3138805627822876 0.9603174603174603 0.9193548387096774\n",
      "10 2000 1.3143413066864014 0.9596774193548387 0.9295774647887324\n",
      "10 2400 1.301513671875 0.9725085910652921 0.8904109589041096\n",
      "11 0 1.3803668022155762 0.8935361216730038 0.7021276595744681\n",
      "11 400 1.3702261447906494 0.9041095890410958 0.8333333333333334\n",
      "11 800 1.2993906736373901 0.9746376811594203 0.9259259259259259\n",
      "11 1200 1.2810218334197998 0.9931506849315068 0.9821428571428571\n",
      "11 1600 1.29237961769104 0.9842931937172775 0.958904109589041\n",
      "11 2000 1.3960182666778564 0.8780487804878049 0.7818181818181819\n",
      "11 2400 1.3076266050338745 0.9663865546218487 0.92\n",
      "12 0 1.2740588188171387 1.0 1.0\n",
      "12 400 1.373016357421875 0.9016393442622951 0.8636363636363636\n",
      "12 800 1.305710792541504 0.9678899082568807 0.9166666666666666\n",
      "12 1200 1.2821658849716187 0.9921875 1.0\n",
      "12 1600 1.3136026859283447 0.9573170731707317 0.9682539682539683\n",
      "12 2000 1.291735291481018 0.9822784810126582 0.8970588235294118\n",
      "12 2400 1.2740267515182495 1.0 1.0\n",
      "13 0 1.27499520778656 1.0 1.0\n",
      "13 400 1.2740483283996582 1.0 1.0\n",
      "13 800 1.3006969690322876 0.9739130434782609 0.9642857142857143\n",
      "13 1200 1.2831953763961792 0.9905660377358491 0.9642857142857143\n",
      "13 1600 1.295846700668335 0.972972972972973 0.9420289855072463\n",
      "13 2000 1.3386294841766357 0.9359430604982206 0.7662337662337663\n",
      "13 2400 1.3042057752609253 0.9701492537313433 0.9310344827586207\n",
      "14 0 1.278923749923706 0.9957081545064378 0.98\n",
      "14 400 1.310533881187439 0.9637305699481865 0.8939393939393939\n",
      "14 800 1.2740105390548706 1.0 1.0\n",
      "14 1200 1.2742260694503784 1.0 1.0\n",
      "14 1600 1.2753076553344727 1.0 1.0\n",
      "14 2000 1.314448595046997 0.9596774193548387 0.9206349206349206\n",
      "14 2400 1.3182445764541626 0.9567567567567568 0.9508196721311475\n"
     ]
    }
   ],
   "source": [
    "model.fine_tuning(True)\n",
    "print(sum(p.numel() for p in model.parameters()) / 10000)\n",
    "train(15)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "ce2b2bff",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Loading cached processed dataset at d:\\暂存\\ResumeSystem\\ner\\data\\validation\\cache-26bc23ad408c2b6f.arrow\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0\n",
      "1\n",
      "2\n",
      "3\n",
      "4\n",
      "0.9970202622169249 0.9840819886611426\n"
     ]
    }
   ],
   "source": [
    "def test():\n",
    "    model_load = torch.load('ner2.model')\n",
    "    model_load.eval()\n",
    "\n",
    "    loader_test = torch.utils.data.DataLoader(dataset=Dataset('validation'),\n",
    "                                              batch_size=128,\n",
    "                                              collate_fn=collate_fn,\n",
    "                                              shuffle=True,\n",
    "                                              drop_last=True)\n",
    "\n",
    "    correct = 0\n",
    "    total = 0\n",
    "\n",
    "    correct_content = 0\n",
    "    total_content = 0\n",
    "\n",
    "    for step, (inputs, labels) in enumerate(loader_test):\n",
    "        if step == 5:\n",
    "            break\n",
    "        print(step)\n",
    "\n",
    "        with torch.no_grad():\n",
    "            #[b, lens] -> [b, lens, 8] -> [b, lens]\n",
    "            outs = model_load(inputs)\n",
    "\n",
    "        #对outs和label变形,并且移除pad\n",
    "        #outs -> [b, lens, 8] -> [c, 8]\n",
    "        #labels -> [b, lens] -> [c]\n",
    "        outs, labels = reshape_and_remove_pad(outs, labels,\n",
    "                                              inputs['attention_mask'])\n",
    "\n",
    "        counts = get_correct_and_total_count(labels, outs)\n",
    "        correct += counts[0]\n",
    "        total += counts[1]\n",
    "        correct_content += counts[2]\n",
    "        total_content += counts[3]\n",
    "\n",
    "    print(correct / total, correct_content / total_content)\n",
    "\n",
    "\n",
    "test()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "372f735b",
   "metadata": {},
   "outputs": [],
   "source": [
    "def custom_predict(sentence):\n",
    "\n",
    "    model_load = torch.load('ner5.model')\n",
    "    model_load = model_load.to(device)\n",
    "    model_load.eval()\n",
    "\n",
    "    inputs = tokenizer.encode_plus([sentence],\n",
    "                                   truncation=True,\n",
    "                                   padding=True,\n",
    "                                   return_tensors='pt',\n",
    "                                   is_split_into_words=True).to(device)\n",
    "\n",
    "    with torch.no_grad():\n",
    "        outputs = model_load(inputs)\n",
    "\n",
    "    # [b, lens] -> [b, lens]\n",
    "    preds = outputs.argmax(dim=2)[0]\n",
    "\n",
    "    result = ''\n",
    "\n",
    "    res = [set() for _ in range(3)]\n",
    "    tmp = ''\n",
    "    current_flag = -1\n",
    "\n",
    "    for i in range(len(preds)):\n",
    "        if inputs['attention_mask'][0][i] == 1:\n",
    "            result += tokenizer.decode(inputs['input_ids'][0][i])+' '\n",
    "            result += str(preds[i].item())+' '\n",
    "            num = preds[i].item()\n",
    "            # num 不为0和7和#表示该词为关键词\n",
    "            if (num != 0 and num != 7 and num != '#'):\n",
    "                # 关键词开始为奇数\n",
    "                if (num & 1):\n",
    "                    # 将形如 广5东6广5州6 拆成两个词\n",
    "                    if (len(tmp) > 1):\n",
    "                        res[(current_flag-1)//2].add(tmp)\n",
    "                        tmp = ''\n",
    "                    current_flag = num\n",
    "                    tmp += tokenizer.decode(inputs['input_ids'][0][i])\n",
    "\n",
    "                    # 防止形如 X4X4 出现\n",
    "                elif (num & 1 == 0 and current_flag != -1):\n",
    "                    tmp += tokenizer.decode(inputs['input_ids'][0][i])\n",
    "            else:\n",
    "                if (len(tmp) > 1):\n",
    "                    # current_flag 1对应姓名下标0，3对应组织下表1，5对应地点下表2\n",
    "                    res[(current_flag-1)//2].add(tmp)\n",
    "                    tmp = ''\n",
    "                    current_flag = -1\n",
    "\n",
    "    return res"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "id": "5614cd38",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "姓名[],组织['统战统战部'],地点[]\n"
     ]
    }
   ],
   "source": [
    "input_sentence = \"李白在上海的华东理工大学读书\"\n",
    "output_prediction = custom_predict(input_sentence)\n",
    "print(\n",
    "    f'姓名{list(output_prediction[0])},组织{list(output_prediction[1])},地点{list(output_prediction[2])}')"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
